#!/usr/bin/env python

import sys,pickle,os,glob,traceback,sqlite3,scipy
import matplotlib
if "DISPLAY" not in os.environ: matplotlib.use("AGG")
else: matplotlib.use("GTK")
from optparse import OptionParser
import pygtk
pygtk.require("2.0") 
import gobject,gtk,gtk.glade
from pylab import *
import ocrolib
from ocrolib import docproc,dbhelper

def TODO():
    print """
    accept-selection button (turns everything else to _)
    redo-selection button (turns selection to _)
    reject-selection button (turns selection to #)
    misseg-selection button (turns selection to ~)
    cluster-classifier using output of manual classification
    special chars: #=reject, _=unclassified, ~=missegmentation
    put confidence into database and display
    nnet+exception
    centering on largest component only
    sort by size, similarity, frequency
    FIXME unify the two classification loops
    cluster in-place
    cluster+label selection (separate dialog)
    train new classifier
    train classifier on selection
    freeze characters (keep in current window despite different labels)
    """

# Command-line interface.  The usage text below is what optparse shows for
# -h/--help.
parser = OptionParser(usage="""
%prog [options] [input.db]

Edit character- and cluster-databases.

OCRopus isolated character training data is stored in sqlite3 databases
with a simple structure.  Isolated characters are (usually) stored in
a table called "chars", while clusters are stored in a table called
"clusters".  This application lets you browse and label these characters
and clusters.

You need to specify the correct table name for your database using the "-t"
flag.

Updates to the database are transacted and happen immediately; you don't need
to save.  Even concurrent editing and viewing is possible.  There is no
undo, but it's usually easy to manually undo changes.

There is a variety of sorting, training, and classification options available.
Particularly useful is the ability to sort by size and aspect ratio to quickly 
identify characters that have implausible sizes or aspect ratios.  Once you have
a character model, you can also sort by classifier confidence and focus correction
efforts on the characters with the least confidence.

Eventually, you will be able to cluster and train right from this interface, but
for now, use the external ocropus-ctrain and ocropus-cluster-db programs.

""")
# Default model name; it is the default argument of load_classifier() and
# load_confidence() below.
parser.add_option("-m","--model",help="model used for classification",default="default.cmodel")
parser.add_option("-v","--verbose",help="verbose",action="store_true")
# When no table is given, the first table listed in sqlite_master is used
# (see the database setup further down).
parser.add_option("-t","--table",help="which table to edit",default=None)
# use this with old (unflipped) cmodels trained from Python
# NOTE: the option defaults to 1 and uses store_false, so passing -F turns
# flipping *off*.
parser.add_option("-F","--flip",help="flip characters before handing to classifier",default=1,action="store_false")
# Used as the row limit of the query in get_clusters() and as the cap in set_store().
parser.add_option("-B","--batchsize",help="maximum number of characters loaded for display",type="int",default=10000)
# If given, main() calls set_store() with this value before entering the GTK loop.
parser.add_option("-s","--select",help="initial character to select",default=None)
(options,args) = parser.parse_args()
# ocrolib helper; presumably expands the positional arguments (e.g. glob
# patterns) -- TODO confirm against ocrolib.
args = ocrolib.expand_args(args)

# Turn on matplotlib's interactive mode (ion comes from "from pylab import *").
ion()

### misc utility functions

def detuple(item):
    """Unwrap nested lists/tuples by repeatedly taking the first element.

    Returns the first value reached that is neither exactly a `list`
    nor exactly a `tuple` (subclasses are not unwrapped).  An empty
    list or tuple encountered on the way raises IndexError.
    """
    while type(item) is tuple or type(item) is list:
        item = item[0]
    return item

def numpy2pixbuf(a,limit=40):
    """Convert a numpy array to a gray-level RGB pixbuf.

    If the largest dimension of `a` exceeds `limit`, the array is scaled
    down (linear interpolation) so that this dimension becomes roughly
    `limit`, and a 3x3 green marker is painted into the top-left corner
    to show that the image was shrunk.

    The values are multiplied by 255 and stored as unsigned bytes in all
    three channels.  Assumes `a` is 2-D with values in [0,1] -- TODO
    confirm against callers other than set_store().
    """
    # `import scipy` at the top of the file does not by itself load the
    # ndimage subpackage, so import it explicitly where it is needed.
    import scipy.ndimage
    largest = max(a.shape)
    scaled = largest>limit
    if scaled:
        # scipy.ndimage.zoom is the same function as the older
        # scipy.ndimage.interpolation.zoom spelling.
        a = scipy.ndimage.zoom(array(a,'f'),limit/float(largest),order=1)
    data = zeros(list(a.shape)+[3],'B')
    for channel in range(3):
        data[:,:,channel] = 255*a
    if scaled:
        # green corner marker: "this image has been scaled down"
        data[:3,:3,:] = 0
        data[:3,:3,1] = 255
    return gtk.gdk.pixbuf_new_from_array(data,gtk.gdk.COLORSPACE_RGB,8)

### load the cluster images

print "opening table"
fname = "clusters.db"
if len(args)>0: fname = args[0]
if not os.path.exists(fname):
    print fname,"not found"
    sys.exit(1)

db = sqlite3.connect(fname,timeout=600.0)
db.row_factory = dbhelper.DbRow
db.text_factory = sqlite3.OptimizedUnicode
db.execute("pragma synchronous=0")

if options.table is None:
    names = [r[0] for r in db.execute("select name from sqlite_master where type='table'")]
    options.table = names[0]
    print "defaulting to table",options.table

dbhelper.charcolumns(db,options.table)
db.commit()

def get_clusters(cls,limit=options.batchsize,table=options.table):
    print "get_clusters",cls
    cur = db.execute("select * from %s where cls=? order by count,cost desc limit %s"%(table,limit),[cls])
    clusters = []
    for c in cur:
        image = dbhelper.blob2image(c.image)/255.0
        cluster = ocrolib.Record(id=c.id,
                                file=c.file,
                                image=image,
                                cost=c.cost,
                                rel=c.rel,
                                cls=c.cls,
                                count=c.count,
                                classes=c.classes,
                                cluster=c.cluster)
        if cluster.cls is None:
            cluster.cls = "_"
        clusters.append(cluster)
    print "got",len(clusters),"clusters"
    return clusters

def get_classes(table=options.table):
    print "getting class list"
    return [r[0] for r in db.execute("select distinct(cls) from %s order by cls"%table)]
    print "done"

def compute_combolist():
    """Compute the combolist from the current charlist.

    Reloads the class labels from the database into the global
    `charlist` (always starting with "_"), rebuilds the global
    `combolist` model from it, installs that model in the class
    selector, and re-selects the previously active entry if it is
    still present.
    """
    global charlist,combolist
    select = class_selector.get_active_text()
    # "_" is always offered first; skip it if the table already contains
    # it so it is not listed twice.  NULL labels are skipped as well:
    # they cannot be matched by the "cls=?" query in get_clusters().
    charlist = ["_"]+[c for c in sorted(get_classes())
                      if c is not None and c!="_"]
    combolist = gtk.ListStore(str)
    for char in charlist:
        combolist.append([char])
    class_selector.set_model(combolist)
    class_selector.set_text_column(0)
    if select in charlist:
        class_selector.set_active(charlist.index(select))

def set_store(target_cls,sortfun=None):
    """Set the store for the target class."""
    global grid
    grid = gtk.ListStore(gtk.gdk.Pixbuf,
                         str,
                         gobject.TYPE_PYOBJECT)
    rownum = 0
    selected = get_clusters(cls=target_cls)
    if sortfun is not None:
        print "sorting",len(selected),"with",sortfun
        selected = sortfun(selected)
    for cluster in selected:
        pixbuf = numpy2pixbuf(1.0-cluster.image)
        row = [pixbuf,cluster.cls,cluster]
        grid.append(row)
        rownum += 1
        if rownum>options.batchsize: break
    cluster_viewer.set_model(grid)
    move_to(0)

def move_to(index):
    """Put the cursor on item `index`, select it, and refresh the info line.

    `index` may be a plain index or a (nested) tuple/list path; it is
    reduced to its first scalar element with detuple().
    """
    target = detuple(index)
    cluster_viewer.set_cursor(target)
    cluster_viewer.select_path(target)
    update_info()

def update_info():
    """Update the character information associated with the currently
    selected character.

    Shows id, classes, count, rel (each number formatted with two
    decimals), cost and the last 20 characters of the file name in the
    info area.  Does nothing when the icon view has no cursor.
    """
    cursor = cluster_viewer.get_cursor()
    if cursor is None: return
    row = grid[cursor[0][0]][2]
    rel = row.rel
    if rel is not None and rel!="":
        rel = " ".join(["%.2f"%float(x) for x in row.rel.split()])
    try:
        info = "%-7d: %s %s (%s) %s %s"%(row.id,row.classes,row.count,rel,row.cost,row.file[-20:])
    except Exception:
        # e.g. row.file is None; fall back to showing only the id.
        # (Not a bare except, so KeyboardInterrupt/SystemExit pass through.)
        info = "%-7d"%(row.id,)
    info_area.set_text(info)

def get_extended():
    """Ask for a transcript in a modal dialog and return the entered text.

    Used for extended labels.  Pressing Enter in the entry field closes
    the dialog just like the OK button does.
    """
    flags = gtk.DIALOG_MODAL|gtk.DIALOG_DESTROY_WITH_PARENT
    dialog = gtk.MessageDialog(None,flags,gtk.MESSAGE_QUESTION,gtk.BUTTONS_OK,None)
    dialog.set_markup("Transcript:")
    entry = gtk.Entry()

    def respond(widget,target,response):
        # "activate" handler: end the dialog with the given response
        return target.response(response)

    entry.connect("activate",respond,dialog,gtk.RESPONSE_OK)
    dialog.vbox.pack_end(entry,True,True,0)
    dialog.show_all()
    dialog.run()
    text = entry.get_text()
    dialog.destroy()
    return text

def set_dist(x,s):
    """Return the smallest ocrolib.symdist error between `x` and the
    members of `s` that have the same shape as `x`.

    Members with a different shape are ignored; if none matches, the
    initial value 10000 is returned.  `x` and the first member of `s`
    must be numpy arrays.
    """
    assert type(x)==ndarray,x
    assert type(s[0])==ndarray,s[0]
    best = 10000
    for candidate in s:
        if candidate.shape==x.shape:
            # only the first of the three returned values is used
            err,_rerr,_rest = ocrolib.symdist(x,candidate)
            best = min(err,best)
    return best

### toolbar commands

def cmd_similar(*args):
    """Sort by similarity to the selected item.

    Every item is compared (after docproc.isotropic_rescale) against the
    selected items with set_dist(), and the grid is reordered by
    increasing distance.  The selection is cleared afterwards.  Does
    nothing when no item is selected.  Always returns 1.
    """
    global grid,cluster_viewer
    assert type(grid)==gtk.ListStore
    chosen = cluster_viewer.get_selected_items()
    if not chosen:
        # set_dist() needs at least one reference image; without this
        # guard it fails with an IndexError.
        return 1
    selected = [docproc.isotropic_rescale(grid[i][2].image) for i in chosen]
    dists = [set_dist(docproc.isotropic_rescale(x[2].image),selected) for x in grid]
    # reorder() wants, for each new position, the old position of the row
    index = [int(i) for i in argsort(dists)]
    grid.reorder(index)
    cluster_viewer.unselect_all()
    cluster_viewer.scroll_to_path(1,1,0,0)
    return 1

# sorting

def sort_menu_populate(cmd):
    """Build the radio menu with the available sort criteria.

    Each entry's "toggled" signal is connected to `cmd` with the
    criterion name as extra argument.  No entry is marked as default,
    so the last one is activated.  Returns the gtk.Menu.
    """
    criteria = ["count","width","height","aspect","area","pixels","components","holes","endpoints","junctions"]
    menu = gtk.Menu()
    item = None
    for criterion in criteria:
        # passing the previous item as group chains all entries into one radio group
        item = gtk.RadioMenuItem(group=item,label=criterion)
        item.connect("toggled",cmd,criterion)
        menu.append(item)
    item.activate() # activate the last one if there is no default
    return menu

def cmd_sel_sort(item,which):
    global sort_style
    if item.get_active():
        print "sort style",item,which
        sort_style = which

def charprop(image,kind):
    if kind=="holes":
        result = ocrolib.hole_counts(image,1.0)
    elif kind=="components":
        result = ocrolib.component_counts(image,1.0)
    elif kind=="junctions":
        result = ocrolib.junction_counts(image,1.0)
    elif kind=="endpoints":
        result = ocrolib.endpoints_counts(image,1.0)
    else:
        raise Exception("unknown charprop")
    print image.shape,kind,result
    return result

def tighten(image):
    """Crop `image` to the bounding box of its positive content.

    A row or column counts as occupied when its sum is greater than
    zero.  Returns the sub-array spanning the first to the last occupied
    row and column (both inclusive).  An image without any occupied
    row/column is returned unchanged.
    """
    rowsums = sum(image,axis=1)
    rows = arange(len(rowsums))[rowsums>0]
    colsums = sum(image,axis=0)
    cols = arange(len(colsums))[colsums>0]
    if len(rows)==0 or len(cols)==0:
        # nothing to crop to
        return image
    # The bounds must come from the *indices* of the occupied rows and
    # columns, not from the sums themselves; the upper bounds are
    # exclusive, hence +1.
    return image[rows[0]:rows[-1]+1,cols[0]:cols[-1]+1]
        
def cmd_sort(*args):
    global grid,cluster_viewer,sort_style
    assert type(grid)==gtk.ListStore
    print "cmd_sort",sort_style
    images = [tighten(x[2].image) for x in grid]
    if sort_style=="width":
        dists = [image.shape[1] for image in images]
    elif sort_style=="height":
        dists = [image.shape[0] for image in images]
    elif sort_style=="aspect":
        dists = [image.shape[1]*1.0/(0.0001+image.shape[0]) for image in images]
    elif sort_style=="area":
        dists = [prod(image.shape) for image in images]
    elif sort_style=="pixels":
        dists = [sum(image) for image in images]
    elif sort_style=="components":
        dists = [charprop(image,"components") for image in images]
    elif sort_style=="holes":
        dists = [charprop(image,"holes") for image in images]
    elif sort_style=="endpoints":
        dists = [charprop(image,"endpoints") for image in images]
    elif sort_style=="junctions":
        dists = [charprop(image,"junctions") for image in images]
    else:
        dists = [-x[2].count for x in grid]
    print dists[:20]
    index = array(argsort(array(dists)))
    index = [int(i) for i in index]
    grid.reorder(index)
    cluster_viewer.unselect_all()
    cluster_viewer.scroll_to_path(1,1,0,0)

# classifying

def classifier_menu(cmd):
    """Build a radio menu listing the available model files.

    The menu starts with the entry "None", followed by the *.cmodel,
    *.model and *.pymodel files of the current directory and, if
    ocrolib.finddir("models") succeeds, of that directory.  Each entry's
    "toggled" signal is connected to `cmd` with the model name as extra
    argument.  The "None" entry is the one marked active.  Returns the
    gtk.Menu.
    """
    patterns = ["*.cmodel","*.model","*.pymodel"]
    models = ["None"]
    for pattern in patterns:
        models += glob.glob(pattern)
    try:
        modeldir = ocrolib.finddir("models")
    except IOError:
        # no shared model directory: offer only the local files
        modeldir = None
    if modeldir is not None:
        for pattern in patterns:
            models += glob.glob(modeldir+"/"+pattern)
    menu = gtk.Menu()
    group = None
    for model in models:
        item = gtk.RadioMenuItem(group=group,label=model)
        group = item
        item.connect("toggled",cmd,model)
        menu.append(item)
        # The placeholder entry is the *string* "None"; the former
        # comparison against the None object could never be true.
        if model=="None": item.set_active(True)
    return menu

def cmd_sel_class(item,which):
    if item.get_active():
        print "activated class",item,which
        load_classifier(which)

# Model used by cmd_nn() (the "Classify" button); set by load_classifier().
# None means that no model is loaded.
classifier = None

class NoException:
    # Empty placeholder class; it is not referenced anywhere in the
    # visible part of this file.
    pass

def load_classifier(which=options.model):
    global classifier
    if which is None:
        classifier = None
        return
    try:
        classifier = ocrolib.load_component(which)
        # classifier.info()
        # classifier.getExtractor().info()
        print "loaded classifier",which,classifier
    except:
        traceback.print_exc()
        print "loading",options.model,"failed"
        classifier = None

# load_classifier()

# confidence 

confidence = None

def load_confidence(which=options.model):
    global confidence
    if which is None:
        confidence = None
        return
    try:
        confidence = ocrolib.load_component(which)
        # confidence.info()
        # confidence.getExtractor().info()
        print "loaded confidence",which,confidence
    except:
        traceback.print_exc()
        print "loading",options.model,"failed"
        confidence = None

# load_confidence()

def cmd_sel_conf(item,which):
    if item.get_active():
        print "activated conf",item,which
        load_confidence(which)

def cmd_nn(*args):
    """Classify with the selected classifier."""
    global grid,cluster_viewer
    assert type(grid)==gtk.ListStore
    count = 0
    images = []
    geometries = []
    selected = list(cluster_viewer.get_selected_items())
    for i in selected:
        row = grid[i]
        pat = row[2]
        image = pat.image
        images.append(image)
        if hasattr(pat,"rel"):
            geometry = docproc.rel_geo_normalize(pat.rel)
        else:
            geometry = None
        geometries.append(geometry)
    results = classifier.coutputs_batch(images,geometries)
    for k in range(len(selected)):
        i = selected[k]
        row = grid[i]
        outputs = results[k]
        outputs = [(x[0],-log(x[1])) for x in outputs]
        outputs.sort(key=lambda x:x[1])
        if len(outputs)<1:
            row[1] = ""
            row[2].cls = ""
        else:
            cls,cost = outputs[0]
            row[1] = cls
            row[2].cls = cls
            db.execute("update %s set cls=? where id=?"%options.table,[row[2].cls,row[2].id])
        if count<10 or count%100==0:
            if len(outputs)>1:
                print count,row[2].cls,outputs[0],type(row[2].cls),row[2].id
            else:
                print count,"no output"
        count += 1
    db.commit()
    return 1

def equivalent(c):
    """Return the list of labels considered interchangeable with `c`.

    The list always contains `c` itself.  Members of the look-alike
    groups 0/o/O and 1/l/I/| are equivalent to every other member of
    their group, and a label between "A" and "Z" or between "a" and "z"
    is equivalent to its other-case form.  A list is returned in every
    case, so that `x in equivalent(c)` is a membership test and never a
    substring test.
    """
    result = [c]
    # Look-alike groups are checked for every label; letters such as "o"
    # or "l" used to return early from the case rule and never got here.
    for group in (["0","o","O"],["1","l","I","|"]):
        if c in group:
            result += [x for x in group if x not in result]
    if c>="A" and c<="Z":
        other = c.lower()
    elif c>="a" and c<="z":
        other = c.upper()
    else:
        other = c
    if other not in result:
        result.append(other)
    return result

def cmd_class(*args):
    """Highlight misclassified samples and sort by classifier confidence."""
    global grid,cluster_viewer,confidence
    assert type(grid)==gtk.ListStore
    cluster_viewer.unselect_all()
    dists = []
    n = len(grid)
    print "parallel classifier",n
    inputs = []
    geometries = []
    for i in range(n):
        row = grid[i]
        image = row[2].image
        inputs.append(image)
        if hasattr(row[2],"rel"):
            geometry = docproc.rel_geo_normalize(row[2].rel)
        else:
            geometry = None
        geometries.append(geometry)
    results = confidence.coutputs_batch(inputs,geometries)
    for i in range(n):
        outputs = results[i]
        outputs = [(x[0],-log(x[1])) for x in outputs]
        outputs.sort(key=lambda x:x[1])
        if len(outputs)<1: continue
        cls,cost = outputs[0]
        if row[2].cls not in equivalent(cls):
            cluster_viewer.select_path(i)
        pcost = 9999
        for cls,cost in outputs:
            if row[2].cls in equivalent(cls):
                pcost = cost
                break
        dists.append(pcost)
        if i<10 or i%1000==0:
            print i,row[2].cls,outputs[0],pcost
    index = argsort(-array(dists))
    index = [int(i) for i in index]
    grid.reorder(index)
    # cluster_viewer.unselect_all()
    cluster_viewer.scroll_to_path(1,1,0,0)
    return 1

def cmd_train(*args):
    """Handler of the "Train" button; a stub that ignores its arguments
    and only reports the event as handled."""
    handled = 1
    return handled

### basic event handlers

def on_comboboxentry1_changed(entry):
    """Show the class chosen in the class selector.

    Nothing happens while the selector's active text is None or empty.
    """
    chosen = entry.get_active_text()
    if chosen is None or chosen=="":
        return
    set_store(chosen)

def on_iconview1_item_activated(*args):
    """Refresh the info line; the signal arguments are ignored.

    main() connects this to both the item-activated and the
    selection-changed signal of the icon view.
    """
    update_info()

def on_iconview1_motion_notify_event(widget,event):
    """Pointer-motion handler of the icon view; currently does nothing
    and returns 0."""
    # Following the pointer with the cursor is disabled:
    # item = widget.get_item_at_pos(event.x,event.y)
    # if item is not None: move_to(item[0])
    handled = 0
    return handled

def on_iconview1_button_press_event(widget,event):
    """Mouse-button handler of the icon view.

    Button 1 refreshes the info line and returns 0.  Button 2 over an
    item moves the cursor there and resets the item's label to "_" in
    the grid and in the database.  Every other case returns None.
    """
    if event.button==1:
        update_info()
        return 0
    if event.button==2:
        # event coordinates are floats; the position lookup takes ints
        item = widget.get_item_at_pos(int(event.x),int(event.y))
        if item is not None:
            s = "_"
            item = detuple(item)
            move_to(item)
            row = grid[item]
            row[1] = s
            row[2].cls = s
            db.execute("update %s set cls=? where id=?"%options.table,[row[2].cls,row[2].id])
            # Commit right away, like the other edit paths do; otherwise
            # the change stays in an open transaction until some later
            # commit happens.
            db.commit()

def on_iconview1_key_press_event(widget,event):
    """Keyboard handler of the icon view.

    ^W / ^Z switch to the previous / next entry of the class selector,
    ^R reloads the class list.  Any printable key or ^D labels the
    selected items (or the item under the cursor when nothing is
    selected): "." asks for an extended label in a dialog, ^D gives the
    empty label, "`" and "~" are swapped, space gives "_", every other
    key is used as is.  The labels are written to the database and
    committed, the selection is cleared and the cursor moves to the item
    after the last labelled one (staying on the last item of the grid).

    Returns 1 when the key was handled and 0 otherwise.
    """
    key = event.string
    if key=="\027": # ^W
        i = class_selector.get_active()
        if i<=0: return 1
        class_selector.set_active(i-1)
        update_info()
        return 1
    if key=="\032": # ^Z
        i = class_selector.get_active()
        if i<0: i = 0
        if i>=len(combolist)-1: return 1  # it's one longer than the number of entries
        class_selector.set_active(i+1)
        update_info()
        return 1
    if key=="\022": # ^R = reload
        compute_combolist()
        return 1
    if key>=" " or key=="\004":
        s = key
        if s==".":
            s = get_extended()
        elif s=="\004": s = ""
        elif s=="`": s = "~" # switch ` and ~
        elif s=="~": s = "`"
        elif s==" ": s = "_"
        items = cluster_viewer.get_selected_items()
        if not items:
            cursor = widget.get_cursor()
            if cursor is None:
                # neither a selection nor a cursor: nothing to label
                return 1
            items = [cursor[0][0]]
        last = -1
        for item in items:
            item = detuple(item)
            if item>last: last = item
            row = grid[item]
            row[1] = s
            row[2].cls = s
            db.execute("update %s set cls=? where id=?"%options.table,[row[2].cls,row[2].id])
        db.commit()
        cluster_viewer.unselect_all()
        # do not move the cursor past the end of the grid
        move_to(min(last+1,len(grid)-1))
        return 1
    return 0

def build_toolbar():
    """Populate the toolbar widget of the glade tree.

    Adds, in this order: "Similar", "Sort" (with the sort-criterion
    menu), "Classify" and "Confidence" (each with a model menu) and
    "Train".  The toolbar is stored in the global `toolbar`.
    """
    global toolbar
    toolbar = main_widget_tree.get_widget("toolbar")
    toolbar.set_style(gtk.TOOLBAR_BOTH)

    def add_button(label,on_click):
        # plain button without a menu
        button = gtk.ToolButton(label=label)
        button.connect("clicked",on_click)
        toolbar.insert(button,-1)

    def add_menu_button(label,make_menu,on_select,on_click):
        # button with a drop-down menu; make_menu wires the menu entries to on_select
        button = gtk.MenuToolButton(None,label)
        menu = make_menu(on_select)
        button.set_menu(menu)
        menu.show_all()
        button.show_all()
        button.connect("clicked",on_click)
        toolbar.insert(button,-1)

    add_button("Similar",cmd_similar)
    # a "Freq" button (cmd_freq) is currently disabled
    add_menu_button("Sort",sort_menu_populate,cmd_sel_sort,cmd_sort)
    add_menu_button("Classify",classifier_menu,cmd_sel_class,cmd_nn)
    add_menu_button("Confidence",classifier_menu,cmd_sel_conf,cmd_class)
    add_button("Train",cmd_train)

def main():
    """Build the GUI from the glade file and run the GTK main loop.

    Connects the signal handlers, builds the toolbar, configures the
    icon view, fills the class selector and, if --select was given,
    shows that class right away.
    """
    global main_widget_tree,class_selector,cluster_viewer,info_area
    main_widget_tree = gtk.glade.XML(ocrolib.findfile("gui/ocroold-cedit.glade"))
    get_widget = main_widget_tree.get_widget
    handlers = {
        "on_window1_destroy_event" : gtk.main_quit,
        "on_window1_delete_event" : gtk.main_quit,
        "on_iconview1_key_press_event" : on_iconview1_key_press_event,
        "on_iconview1_button_press_event" : on_iconview1_button_press_event,
        "on_iconview1_item_activated" : on_iconview1_item_activated,
        # selection changes are handled exactly like activations
        "on_iconview1_selection_changed" : on_iconview1_item_activated,
        "on_iconview1_motion_notify_event" : on_iconview1_motion_notify_event,
        "on_comboboxentry1_changed" : on_comboboxentry1_changed,
        }
    main_widget_tree.signal_autoconnect(handlers)
    window = get_widget("window1")
    build_toolbar()

    scroller = get_widget("scrolledwindow1")
    cluster_viewer = get_widget("iconview1")
    cluster_viewer.set_selection_mode(gtk.SELECTION_MULTIPLE)
    class_selector = get_widget("comboboxentry1")
    info_area = get_widget("info")
    # grid column 0 holds the pixbuf, column 1 the label (see set_store)
    cluster_viewer.set_item_width(50)
    cluster_viewer.set_pixbuf_column(0)
    cluster_viewer.set_text_column(1)
    assert cluster_viewer is not None
    scroller.show_all()
    cluster_viewer.show_all()

    compute_combolist()
    # class_selector.set_active(0)

    # looked up but not used any further
    status = get_widget("status")
    window.show_all()
    if options.select is not None:
        set_store(options.select)
    gtk.main()

main()
